import math
import numpy as np
import pandas as pd
from numpy import array
from typing import List, Dict, NewType
from dataclasses import dataclass
from lib.metric_manager import WEIGHT_METHODS, MetricManager
from lib.metric_type.metric_type import Level
from lib.score_result import LevelResults

# Default (metric type, weight) pairs. Declaring each type next to its weight
# keeps the two parallel public lists below from drifting out of sync.
_DEFAULT_TYPE_WEIGHTS = (
    ("CapacityMetric", 0.2),
    ("LoadMetric", 0.2),
    ("LatencyMetric", 0.3),
    ("ErrorMetric", 0.3),
)

# Metric type names, in the order every per-level weight list follows.
TYPES = [name for name, _ in _DEFAULT_TYPE_WEIGHTS]

# Default weight for each entry of TYPES (same position); nominally sums to 1.
TYPEWEIGHTS = [weight for _, weight in _DEFAULT_TYPE_WEIGHTS]


@dataclass
class TypeWeight:
    """A metric type paired with its weight within a single level."""

    type: str  # metric type name (one of TYPES)
    weight: float  # relative weight of this metric type at its level


# Weight of every metric type at each level: Level -> [TypeWeight, ...].
LevelTypeWeights = NewType(
    "LevelTypeWeights", Dict[Level, List[TypeWeight]]
)

# Weight of every metric within each metric type: type name -> [weight, ...].
TypeMetricWeights = NewType(
    "TypeMetricWeights", Dict[str, List[float]]
)


class WeightsCalculator:
    """Derive metric-type weights per level and metric weights per type.

    Type weights are built once in ``__init__`` from the metrics registered
    in the given ``MetricManager``; per-metric weights are computed on
    demand from observed values with the entropy weight method.
    """

    def __init__(self, metric_manager: MetricManager):
        self.metric_manager = metric_manager
        self.weights_method = WEIGHT_METHODS
        # Level -> list of TypeWeight, in TYPES order.
        self.type_weights = self.setup_type_weights(metric_manager)

    def setup_type_weights(
        self,
        metric_manager: MetricManager
    ) -> LevelTypeWeights:
        """Build the metric-type weights for every ``Level`` member.

        For each level, a metric type with no registered metric gets weight
        0 and the remaining default weights (``TYPEWEIGHTS``) are scaled up
        proportionally so they sum to 1. If no metric type is registered at
        a level at all, every type at that level gets weight 0.

        ``Level.Cluster`` is not checked against the registry: it always
        receives the unscaled defaults.

        Args:
            metric_manager: provides ``registed_metric[level][metric_type]``,
                whose length tells whether that type has any metric.

        Returns:
            Mapping from each ``Level`` member to a list of ``TypeWeight``
            in ``TYPES`` order.
        """
        type_weights = {}

        for level in Level.__members__.values():
            if level == Level.Cluster:
                # Registration is not checked for the cluster level yet; it
                # always gets the default weights unchanged.
                type_weights[level] = [
                    TypeWeight(metric_type, weight)
                    for metric_type, weight in zip(TYPES, TYPEWEIGHTS)
                ]
                continue

            registered = metric_manager.registed_metric[level]
            # A type with no registered metric at this level drops to 0.
            weights = [
                weight if len(registered[metric_type]) > 0 else 0
                for metric_type, weight in zip(TYPES, TYPEWEIGHTS)
            ]

            remaining = sum(weights)
            if remaining > 0:
                # Scale the surviving defaults up so they sum to 1 again.
                adjusted = [weight / remaining for weight in weights]
            else:
                # Nothing registered at this level: avoid 0 / 0
                # (ZeroDivisionError during construction) and give every
                # type weight 0.
                adjusted = [0.0 for _ in weights]

            type_weights[level] = [
                TypeWeight(metric_type, weight)
                for metric_type, weight in zip(TYPES, adjusted)
            ]

        return type_weights

    def cal_metric_weights(
        self,
        level: Level,
        data: List[LevelResults]
    ) -> TypeMetricWeights:
        """Calculate the weight of each metric within each metric type.

        Args:
            level: level whose metric types (``self.type_weights[level]``)
                are evaluated.
            data: ``LevelResults`` of every pod/instance at this level.
                ``item.results[metric_type]`` is iterated as a sequence of
                mappings with a ``"value"`` key. Columns are aligned by
                position, so every item should list its metrics in the same
                order.

        Returns:
            Mapping from metric type to the list of per-metric weights
            (entropy weight method), in the order the metrics appear in
            ``results``. A type with no metrics maps to an empty list.
        """
        res = TypeMetricWeights({})
        for type_weight in self.type_weights[level]:
            metric_type = type_weight.type
            # One row per pod/instance, one column per metric of this type.
            rows = [
                [sr["value"] for sr in item.results[metric_type]]
                for item in data
            ]
            weights = self.cal_metric_weight_ewm(pd.DataFrame(rows))
            res[metric_type] = weights[0].tolist()

        return res

    def cal_metric_weight_ewm(self, x: pd.DataFrame) -> pd.DataFrame:
        """Compute column weights with the entropy weight method.

        For ``n`` rows:

        1. min-max normalise every column to ``[0, 1]``;
        2. ``p_ij = x_ij / sum_i x_ij``;
        3. ``e_j = -(1 / ln n) * sum_i p_ij * ln(p_ij)``, taking
           ``0 * ln(0)`` as 0;
        4. ``d_j = 1 - e_j`` and ``w_j = d_j / sum_k d_k``.

        Degenerate cases:

        * A column with zero range (or containing NaN) carries no
          information: it gets ``e_j = 1`` and therefore weight 0.
        * If every column is such a column, or ``x`` has fewer than two rows
          (``ln n`` would be 0), all columns get the same weight.
        * With no columns the result is empty.

        Args:
            x: observations in rows (one per pod/instance when called from
                ``cal_metric_weights``) and metrics in columns; values are
                assumed numeric.

        Returns:
            DataFrame with a single column labelled ``0`` whose i-th row is
            the weight of the i-th column of ``x`` (by position). Weights
            sum to 1 (up to rounding) unless ``x`` has no columns.
        """
        values = np.asarray(x, dtype=float)
        n_rows, n_cols = values.shape
        if n_cols == 0:
            # Still expose column 0 so callers can do ``result[0]``.
            return pd.DataFrame({0: np.zeros(0)})

        equal = np.full(n_cols, 1.0 / n_cols)
        if n_rows < 2:
            # k = 1 / ln(n) is undefined for n <= 1, and a single row cannot
            # discriminate between metrics anyway.
            return pd.DataFrame({0: equal})

        col_min = values.min(axis=0)
        col_range = values.max(axis=0) - col_min
        # Zero-range columns would normalise to 0/0 = NaN, and NaN inputs
        # give a NaN range; both are treated as carrying no information.
        # (Letting those NaNs be skipped in the entropy sum would give such
        # a column e_j = 0, i.e. the LARGEST weight.)
        informative = col_range > 0

        norm = (values - col_min) / np.where(informative, col_range, 1.0)
        # Every informative column contains a 1 after normalisation, so its
        # sum is positive; the 1.0 placeholder only guards the other columns.
        col_sum = norm.sum(axis=0)
        p = norm / np.where(informative, col_sum, 1.0)
        # Using log(1) = 0 where p == 0 implements 0 * ln(0) = 0.
        p_log_p = p * np.log(np.where(p > 0, p, 1.0))
        entropy = -p_log_p.sum(axis=0) / math.log(n_rows)
        entropy = np.where(informative, entropy, 1.0)

        divergence = 1.0 - entropy
        total = divergence.sum()
        # total == 0 only when no column is informative.
        weights = divergence / total if total > 0 else equal
        return pd.DataFrame({0: weights})

    def cal_weights_critic(self, df: pd.DataFrame) -> np.ndarray:
        """Compute column weights with the CRITIC method.

        Each column is min-max normalised. Its contrast is the standard
        deviation (``np.std`` default, ddof=0) and its conflict is
        ``sum_k (1 - r_jk)`` over the Pearson correlation matrix of the
        normalised columns. The weight is ``C_j / sum(C)`` with
        ``C_j = std_j * conflict_j``.

        Args:
            df: observations in rows, criteria in columns (numeric).

        Returns:
            1-D array with one weight per column of ``df``, by position.

        NOTE(review): degenerate inputs are not handled here. A zero-range
        column normalises to NaN, which propagates to every weight. A
        single-column ``df`` fails because ``np.corrcoef`` then returns a
        0-d value that ``np.sum(..., axis=0)`` rejects.
        """
        X = df.values
        X_norm = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0))

        sigma = np.std(X_norm, axis=0)        # contrast of each criterion
        corr = np.corrcoef(X_norm.T)          # criterion-by-criterion
        C = sigma * np.sum(1 - corr, axis=0)  # information content
        weights = C / np.sum(C)

        return weights
